"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


This script computes the statistics of the flow stability clustering of the APS
dataset and creates Fig. 8.

"""
import sys
import os
PACKAGE_PARENT = '..'
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))



import numpy as np
import pandas as pd
from TemporalNetwork import ContTempNetwork
from TemporalStability import Partition
                              


import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap


import matplotlib as mpl

from collections import Counter


# Colour palette used by the figures (entry 0 is white).
color_list = ["#ffffff",
              "#4ba706",
              "#a2007e",
              "#806dcb",
              "#5eb275",
              "#ca3b01",
              "#01a4d6",
              "#b77600",
              "#a39643",
              "#cc6ea9",
              "#1e5e39",
              "#cb5b5a"]



import matplotlib as mpl
mpl.rcParams['lines.linewidth'] = 1.7


# Input / output locations, relative to the current working directory.
datadir = '../paper_data/aps/'

network = 'complenet'

# directory holding the 'clusters*' result files parsed in the next cell
clusterdir_fwd = os.path.join(datadir, f'clustersI_{network}_forward')

figdir = '../figures/aps/complenet'

file_prefix = 'aps_monthly'

net_file = os.path.join(datadir, 'aps_monthly_lcc_net')

# Unconditional stop: running the whole file top to bottom ends here. The
# analysis below is split into `#%%` cells, presumably meant to be executed one
# at a time in an IDE after this first cell has set up the configuration.
raise Exception('Intentional stop: run the remaining #%% cells of this script one at a time.')

#%%


files = os.listdir(clusterdir_fwd)

# Collect, from the names of the 'clusters*' files, the set of interval start
# indices (third-to-last '_' field), interval end indices (last '_' field) and
# tau_w values (any '_' field starting with 'w', e.g. 'w3650').
interval_starts = set()
interval_ends = set()
tau_ws = set()
for file in files:
    if not file.startswith('clusters'):
        continue
    # split the stem once and reuse it for every field lookup
    extracts = os.path.splitext(file)[0].split('_')
    interval_starts.add(int(extracts[-3]))
    interval_ends.add(int(extracts[-1]))
    for extract in extracts:
        if extract.startswith('w'):
            tau_ws.add(float(extract[1:]))
            
#%%% network file

# Load only the attributes of the temporal network used by this script.
net = ContTempNetwork.load(net_file, attributes_list=['node_to_label_dict',
                                                      'events_table',
                                                      'times',
                                                      'time_grid',
                                                      'num_nodes',
                                                      'time_slices_bounds',
                                                      'time_slices_bounds_datetimes'])

time_slices = net.time_slices_bounds

# inverse of net.node_to_label_dict: label -> node
label_to_node_dict = {l: n for n, l in net.node_to_label_dict.items()}
# sorted() accepts any iterable; no intermediate list needed
tau_ws = sorted(tau_ws)


#%% size of the network per year

# Set of nodes active in each time slice i: the endpoints of every event that
# starts before the slice end and ends at or after the slice start.
active_nodes_per_year = dict()

events = net.events_table
for i, (start, end) in enumerate(zip(time_slices[:-1], time_slices[1:])):
    print(i, start, end)

    # build the overlap mask once and reuse it for both endpoints
    overlapping = events.loc[np.logical_and(events.starting_times < end,
                                            events.ending_times >= start)]
    source_nodes = overlapping.source_nodes.unique()
    target_nodes = overlapping.target_nodes.unique()

    active_nodes_per_year[i] = set(source_nodes) | set(target_nodes)

# Cumulative union of the active nodes up to and including slice i.
active_nodes_per_year_cummul = dict()
nodes = set()
for i, node_set in active_nodes_per_year.items():
    nodes.update(node_set)
    active_nodes_per_year_cummul[i] = nodes.copy()

# Nodes active during each decade: union over a hard-coded range of slice
# indices per decade label.
active_nodes_per_decade = {}

decades_int = dict(zip(['70s', '80s', '90s', '00s'],
                       [range(78, 88), range(88, 98), range(98, 108), range(108, 118)]))
for dec, r in decades_int.items():
    nodes = set()
    for i in r:
        nodes.update(active_nodes_per_year[i])

    active_nodes_per_decade[dec] = nodes.copy()
#%% 

# Precomputed partitions; after wrapping below they are indexed as
# multi_res[network][tau_w]['partitions_<dire>'][k].
# NOTE: pd.read_pickle unpickles arbitrary objects; only load trusted files.
multi_res = pd.read_pickle(os.path.join(datadir, 'multi_complenet_partitions_all_comps.pickle'))

multi_res = {'complenet': multi_res}

networks = ['complenet']
directions = ['forw'] # this is slightly confusing, but here 'forw' means backward in the paper, i.e. clustering in the same direction than the diffusion process which is going backward in time
#%% intersect partitions with active nodes

# Restrict every cluster to the nodes active during the decade its partition
# belongs to. Partitions in each 'partitions_<dire>' list are paired with the
# decades newest first: '00s', '90s', '80s', '70s'.
new_multi_res = {}
for net_name in networks:
    new_multi_res[net_name] = {}
    for tau_w in tau_ws:
        new_multi_res[net_name][tau_w] = {}
        for dire in directions:
            new_parts = new_multi_res[net_name][tau_w]['partitions_' + dire] = []
            for part, dec in zip(multi_res[net_name][tau_w]['partitions_' + dire],
                                 ['00s', '90s', '80s', '70s']):
                active = active_nodes_per_decade[dec]
                new_c_list = [clust & active for clust in part.cluster_list]
                # total number of nodes left after the intersection
                num_nodes = sum(len(clust) for clust in new_c_list)
                new_parts.append(Partition(num_nodes, new_c_list))

multi_res = new_multi_res
#%%


# list of journals in which each author published
df_auth_journal = pd.read_csv('../data/aps/df_author_journal.csv', index_col='author')

# journal -> set of nodes whose author has that 'top_non_prl' journal
# (authors absent from the network are skipped)
class_dict_journ = {journ: {label_to_node_dict[l] for l in df.index.tolist()
                            if l in label_to_node_dict}
                    for journ, df in df_auth_journal.groupby('top_non_prl')}
# groupby drops missing keys, so authors without a top journal go to 'other'
class_dict_journ['other'] = {label_to_node_dict[l]
                             for l in df_auth_journal.loc[df_auth_journal.top_non_prl.isnull()].index.tolist()
                             if l in label_to_node_dict}

# node -> top journal; Series.items() replaces iteritems() (removed in pandas 2.0)
node_to_journal_dict = {label_to_node_dict[a]: c for a, c in df_auth_journal.top_non_prl.items()
                        if a in label_to_node_dict}

# list of countries (in affiliations) for each author
df_auth_country = pd.read_csv('../data/aps/df_author_country.csv', index_col='author')
# Replace missing countries by 'undefined'. Assigning the whole column avoids
# chained assignment, which does not modify the frame under Copy-on-Write.
df_auth_country['top_country'] = df_auth_country['top_country'].fillna('undefined')

# country -> set of nodes whose author has that top country
class_dict_country = {count: {label_to_node_dict[l] for l in df.index.tolist()
                              if l in label_to_node_dict}
                      for count, df in df_auth_country.groupby('top_country')}

# node -> top country
node_to_country_dict = {label_to_node_dict[a]: c for a, c in df_auth_country.top_country.items()
                        if a in label_to_node_dict}

#%% clusters to country/journal


# Translate every cluster into the list of its nodes' labels.
# clusts[tau_w][network][dire][attr] is a list (one entry per partition) of
# clusters, each cluster being a list of journal or country labels.
node_attribute_dicts = {'journal': node_to_journal_dict,
                        'country': node_to_country_dict}

clusts = {}

for tau_w in sorted(tau_ws):
    clusts[tau_w] = {}
    for network in networks:
        clusts[tau_w][network] = {}
        for dire in directions:
            clusts[tau_w][network][dire] = {}
            # `attr` rather than `type` so the builtin is not shadowed
            for attr, node_to_attr in node_attribute_dicts.items():
                clusts[tau_w][network][dire][attr] = [
                    [[node_to_attr[n] for n in c] for c in part.cluster_list]
                    for part in multi_res[network][tau_w]['partitions_' + dire]]

    

#%% decades


# Decade label -> {direction: index of that decade's partition in the
# 'partitions_<dire>' lists}; the lists are stored newest decade first, while
# the dict itself is ordered chronologically.
decades = {dec: {'forw': idx}
           for idx, dec in reversed(list(enumerate(['00s', '90s', '80s', '70s'])))}





#%% size distributions

# Arrays of cluster sizes at tau_w = 3650, per network / direction / decade.
tau_w = 3650
clust_sizes = {}

for network in networks:
    per_direction = clust_sizes[network] = {}
    for dire in directions:
        per_decade = per_direction[dire] = {}
        for dec, dec_dict in decades.items():
            if dire not in dec_dict:
                continue
            partition = multi_res[network][tau_w]['partitions_' + dire][dec_dict[dire]]
            per_decade[dec] = np.array([len(c) for c in partition.cluster_list])

# Log-binned histogram of the 00s cluster sizes.
dec = '00s'
sizes_dec = clust_sizes[network]['forw'][dec]
log_bins = np.logspace(0, np.log10(max(sizes_dec) + 1), 100)
plt.figure()
h = plt.hist(sizes_dec, bins=log_bins, histtype='step', label='Forw')

plt.xscale('log')
plt.yscale('log')
plt.legend()
plt.title(f'Complenet {dec}, cluster size distribution, tau_w={tau_w}')
# plt.savefig(os.path.join(figdir,f'clusters_size_dist_Complenet_{dec}_tau_w{tau_w}.png'))
# plt.savefig(os.path.join(figdir,f'clusters_size_dist_Complenet_{dec}_tau_w{tau_w}.pdf'))



#%% sorted partitions
tau_w = 3650

# Label-lists of each partition's clusters, reordered from largest to smallest.
sorted_parts = {}

for network in networks:
    sorted_parts[network] = {}
    for dire in directions:
        sorted_parts[network][dire] = {
            attr: [sorted(part, key=len, reverse=True)
                   for part in clusts[tau_w][network][dire][attr]]
            for attr in ['journal', 'country']}
                


#%% entropy dist
def compute_entropy(clust):
    """Return the Shannon entropy, in bits, of the labels in ``clust``.

    ``clust`` is an iterable of hashable labels. Label frequencies are turned
    into probabilities p and H = -sum(p * log2(p)) is returned.
    """
    counts = np.array(list(Counter(clust).values()))
    probs = counts / counts.sum()
    return -np.sum(probs * np.log2(probs))

def compute_norm_entropy(clust):
    """Return the entropy of ``clust`` divided by log2(len(clust)).

    log2(len(clust)) is the largest entropy reachable with ``len(clust)``
    items (all labels distinct), so the result lies in [0, 1]. A cluster with
    a single element returns 0.
    """
    freqs = np.array(list(Counter(clust).values()))
    freqs = freqs / freqs.sum()

    if len(clust) == 1:
        return 0
    entropy = -1 * (freqs * np.log2(freqs)).sum()
    return entropy / np.log2(len(clust))
    
    
# Entropy of every (size-sorted) cluster, per network / direction / attribute
# / decade.
clust_entropies = {}
for network in networks:
    by_dire = clust_entropies[network] = {}
    for dire in directions:
        by_attr = by_dire[dire] = {}
        for attr in ['journal', 'country']:
            by_attr[attr] = {
                dec: [compute_entropy(clust)
                      for clust in sorted_parts[network][dire][attr][dec_dict[dire]]]
                for dec, dec_dict in decades.items() if dire in dec_dict}
                    

#%% total entropy plot

# Sum of the cluster entropies for each decade, one curve per attribute.
plt.figure()

decades_list = list(decades.keys())
dire = 'forw'
for attr in ('journal', 'country'):
    totals = [sum(clust_entropies[network][dire][attr][dec]) for dec in decades_list]
    plt.plot(range(len(decades_list)), totals, label='forw ' + attr)
plt.gca().set_xticks(range(len(decades_list)))
plt.gca().set_xticklabels(decades_list)


plt.legend()
plt.ylabel('total clustering entropy')

# plt.savefig(os.path.join(figdir,f'clusters_entropy_total_Complenet_tau_w{tau_w}.png'))
# plt.savefig(os.path.join(figdir,f'clusters_entropy_total_Complenet_tau_w{tau_w}.pdf'))
    
#%% KL dist compared to random

# Nodes of each decade's partition (concatenation of its clusters) at tau_w.
all_nodes = {}
for network in networks:
    all_nodes[network] = {}
    for dire in directions:
        all_nodes[network][dire] = {}
        for dec, dec_dict in decades.items():
            if dire in dec_dict:
                clusters = multi_res[network][tau_w]['partitions_' + dire][dec_dict[dire]].cluster_list
                all_nodes[network][dire][dec] = [n for c in clusters for n in c]


# Reference distributions: fraction of all partition nodes carrying each
# country / journal label, per network, direction and decade.
all_nodes_pdfs = {}
for network in networks:
    all_nodes_pdfs[network] = {}
    # only the directions that exist in all_nodes (it has no other keys)
    for dire in directions:
        all_nodes_pdfs[network][dire] = {}
        for attr in ['country', 'journal']:
            all_nodes_pdfs[network][dire][attr] = {}
            node_to_attr = node_to_country_dict if attr == 'country' else node_to_journal_dict
            for dec, dec_dict in decades.items():
                if dire in dec_dict:
                    co = Counter(node_to_attr[n] for n in all_nodes[network][dire][dec])
                    # normaliser computed once instead of once per label
                    total = sum(co.values())
                    all_nodes_pdfs[network][dire][attr][dec] = {k: v / total for k, v in co.items()}
                    
        
        
def KL_div(prob_dict, network, dire, type, dec, ref_pdfs=None):
    """Return the Kullback-Leibler divergence D(P || Q) in bits.

    Parameters
    ----------
    prob_dict : dict
        P, mapping label -> probability (e.g. the label distribution of one
        cluster).
    network, dire, type, dec :
        Keys selecting the reference distribution
        Q = ref_pdfs[network][dire][type][dec] (label -> probability).
    ref_pdfs : nested dict, optional
        Reference distributions indexed as above. Defaults to the module-level
        ``all_nodes_pdfs``, looked up at call time.

    Returns
    -------
    float
        sum over labels k of Q with P(k) > 0 of P(k) * log2(P(k) / Q(k)).
        Labels of P absent from Q are ignored; labels with P(k) == 0
        contribute 0 (convention 0 log 0 = 0) instead of producing NaN.
    """
    if ref_pdfs is None:
        ref_pdfs = all_nodes_pdfs
    DKL = 0
    for k, q in ref_pdfs[network][dire][type][dec].items():
        p = prob_dict.get(k, 0)
        if p > 0:
            DKL += p * np.log2(p / q)

    return DKL

#%% total KL div


# Per-cluster label Counters and, for each decade, the size-weighted average
# KL divergence between each cluster's label distribution and the reference
# distribution of the whole partition.
clust_counters = dict()

total_KLs = {}
for network in networks:
    clust_counters[network] = {}
    total_KLs[network] = {}
    for dire in directions:
        clust_counters[network][dire] = {}
        total_KLs[network][dire] = {}
        for attr in ['country', 'journal']:
            clust_counters[network][dire][attr] = {}
            total_KLs[network][dire][attr] = {}
            for dec, dec_dict in decades.items():
                if dire not in dec_dict:
                    continue

                counters = [Counter(clust) for clust in sorted_parts[network][dire][attr][dec_dict[dire]]]
                clust_counters[network][dire][attr][dec] = counters

                # weighted average: sum_c |c| * D(P_c || Q) / (number of nodes)
                weighted_sum = 0
                for co in counters:
                    size = sum(co.values())  # computed once per cluster
                    p_clust = {k: v / size for k, v in co.items()}
                    weighted_sum += KL_div(p_clust, network, dire, attr, dec) * size
                total_KLs[network][dire][attr][dec] = weighted_sum / len(all_nodes[network][dire][dec])
            
        

#%% plot total KL


# Total (size-weighted) KL divergence per decade, one curve per attribute.
plt.figure()

decades_list = list(decades.keys())
dire = 'forw'
for attr in ('journal', 'country'):
    kl_values = [total_KLs[network][dire][attr][dec] for dec in decades_list]
    plt.plot(range(len(decades_list)), kl_values, label='forw ' + attr)
plt.gca().set_xticks(range(len(decades_list)))
plt.gca().set_xticklabels(decades_list)



plt.legend()
plt.ylabel('total KL div')


#%%

# plt.savefig(os.path.join(figdir,f'KL_div_all_clusts_tau_w{tau_w}_{dec}.png'))
# plt.savefig(os.path.join(figdir,f'KL_div_all_clusts_tau_w{tau_w}_{dec}.pdf'))


#%% check who is in clusters

# this is the author name disambiguation of the APS dataset
# available at Supplementary Material at https://doi.org/10.1126/science.aaf5239
df_authors = pd.read_csv('../data/aps/APS_authors.csv', index_col=0)

# Clusters of the 00s 'forw' partition, largest first (ties in reverse order
# of a stable ascending sort).
clusters_00s = multi_res[network][tau_w]['partitions_forw'][0].cluster_list
sorted_part_forw00 = list(reversed(sorted(clusters_00s, key=len)))

# (name, country, journal) of every author in the biggest 00s cluster.
# NOTE(review): `iloc` is positional, so this assumes the node labels equal the
# row positions of APS_authors.csv — confirm, otherwise `.loc` is needed.
biggest_cluster = sorted_part_forw00[0]
names = []
for n in biggest_cluster:
    author_name = df_authors.iloc[net.node_to_label_dict[n]]['name']
    names.append((author_name, node_to_country_dict[n], node_to_journal_dict[n]))


    
#%% paper plots
#%%% distribution of all decades in one plot
import matplotlib as mpl

# font sizes used in the figure
ticklabelsize = 14
textsize = 15
labelsize = 15
legendsize = 15

# 4 x 2 grid without spacing; the right column holds one community-size
# histogram per decade.
fig, axes = plt.subplots(4, 2, sharex='col', figsize=(8, 7.9))
fig.subplots_adjust(hspace=0.00)
fig.subplots_adjust(wspace=0.00)

decades_list = ['00s', '90s', '80s', '70s']
max_size = max(max(clust_sizes[network]['forw'][dec]) for dec in decades_list)
# identical log-spaced bins for every decade
hist_bins = np.logspace(0, np.log10(max_size + 1), 50)

for row, dec in enumerate(decades_list):
    hist_ax = axes[row][1]
    h = hist_ax.hist(clust_sizes[network]['forw'][dec],
                     bins=hist_bins,
                     histtype='step',
                     color=color_list[2],
                     label=dec)

    hist_ax.set_xscale('log')
    hist_ax.set_yscale('log')
    # lower y-limit fixed just below 1 (presumably so single-count bins show)
    hist_ax.set_ylim([0.6, hist_ax.get_ylim()[1]])

    hist_ax.text(0.7, 0.7, "'" + dec, transform=hist_ax.transAxes, fontsize=15)

axes[3][1].set_xlabel('community size', fontsize=labelsize)




# Left column, rows 0-1: number of nodes and number of communities per decade.
x_dec = range(len(decades_list))
left_panels = [
    (axes[0][0], [sum(clust_sizes[network]['forw'][dec]) for dec in decades_list], 'num. nodes'),
    (axes[1][0], [len(clust_sizes[network]['forw'][dec]) for dec in decades_list], 'num. communities'),
]
for panel_ax, values, ylabel in left_panels:
    panel_ax.plot(x_dec, values, 'o-', color=color_list[3])
    panel_ax.set_xticks(x_dec)
    panel_ax.set_xticklabels(decades_list)
    panel_ax.set_ylabel(ylabel, fontsize=labelsize)
    # y axis starts at 0 with 5% headroom above the autoscaled top
    panel_ax.set_ylim([0, panel_ax.get_ylim()[1] * 1.05])

# Left column, rows 2-3: total entropy and total KL divergence per decade,
# one curve per attribute (journal / country).
dire = 'forw'
attr_colors = {'journal': color_list[4], 'country': color_list[3]}
kl_series = {attr: [total_KLs[network][dire][attr][dec] for dec in decades_list]
             for attr in ('journal', 'country')}
entropy_series = {attr: [sum(clust_entropies[network][dire][attr][dec]) for dec in decades_list]
                  for attr in ('journal', 'country')}

for panel_ax, series, ylabel, show_legend in ((axes[3][0], kl_series, 'KL div.', True),
                                               (axes[2][0], entropy_series, 'entropy', False)):
    for attr in ('journal', 'country'):
        panel_ax.plot(range(len(decades_list)), series[attr], 'o-',
                      color=attr_colors[attr], label=attr)
    panel_ax.set_xticks(range(len(decades_list)))
    panel_ax.set_xticklabels(decades_list)
    panel_ax.set_ylabel(ylabel, fontsize=labelsize)
    if show_legend:
        panel_ax.legend(fontsize=legendsize)
    # y axis starts at 0 with 5% headroom above the autoscaled top
    panel_ax.set_ylim([0, panel_ax.get_ylim()[1] * 1.05])


## set log tick locators on the histogram column
# A matplotlib Locator keeps a reference to the Axis it is attached to, so it
# must not be shared between axes: with a single shared instance every axis
# would compute its ticks from the limits of the last axis it was set on.
# Create fresh locators for each x and y axis instead.
for ax in axes:
    for axis in (ax[1].xaxis, ax[1].yaxis):
        axis.set_major_locator(mpl.ticker.LogLocator(base=10.0, numticks=5))
        axis.set_minor_locator(mpl.ticker.LogLocator(base=10.0,
                                                     subs=np.arange(1.0, 10.0) * 0.1,
                                                     numticks=10))

fig.align_ylabels([ax[0] for ax in axes])

fig.tight_layout()


# Panel letters, rendered with LaTeX: A-D on the left column, E-H on the right.
# The Text property is `usetex` (there is no `useTex` setter).
for ax, letter in zip(axes, [r'\textbf{A}', r'\textbf{B}', r'\textbf{C}', r'\textbf{D}']):
    ax[0].text(-0.35, 1.0, letter, transform=ax[0].transAxes, usetex=True)

for ax, letter in zip(axes, [r'\textbf{E}', r'\textbf{F}', r'\textbf{G}', r'\textbf{H}']):
    ax[1].text(-0.22, 1.0, letter, transform=ax[1].transAxes, usetex=True)


# uniform tick-label size on every panel
for ax in axes.flatten():
    ax.yaxis.set_tick_params(labelsize=ticklabelsize)
    ax.xaxis.set_tick_params(labelsize=ticklabelsize)

# plt.savefig(os.path.join(figdir,f'complenet_clusts_stats_tau_w{tau_w}.png'))
# plt.savefig(os.path.join(figdir,f'complenet_clusts_stats_tau_w{tau_w}.pdf'))

#%% proportion of journals

# Pool the per-cluster label counts of the 00s 'forw' partition and report the
# fraction of nodes labelled PhysRevB (journal) and USA (country).
# The ratios are printed: a bare expression statement is discarded when run as
# a script, and an interactive cell only echoes its last expression.
total_counter_journals = Counter()
for c in clust_counters[network]['forw']['journal']['00s']:
    total_counter_journals += c

print('fraction of nodes with journal PhysRevB (00s):',
      total_counter_journals['PhysRevB'] / sum(total_counter_journals.values()))

total_counter_countries = Counter()
for c in clust_counters[network]['forw']['country']['00s']:
    total_counter_countries += c

print('fraction of nodes with country USA (00s):',
      total_counter_countries['USA'] / sum(total_counter_countries.values()))
